
### Doubly robust estimating equation
## Input: 
##      current estimate of alpha, beta from pmle, gamma from correctly specified glm
## Output: 
##      New estimate of alpha

dr.estimate.onestep <- function(data, cnt, alpha.start,beta, gamma,
                                Hessian=FALSE,y.hat, max.step = 500, thres = 1e-4,
                                l1.hat="holder",SEED = 1, dr.warm=NULL, dr.warm.obj=NULL){
  ## Doubly robust estimate of alpha by block-coordinate minimisation of the
  ## squared, count-weighted estimating function.
  ##
  ## Arguments:
  ##   data        matrix with columns "l0.1", "l0.2", "a0", "l1", "a1", "y"
  ##               (one row per data pattern; `%*%` is applied to the l0 columns,
  ##               so a numeric matrix is required).
  ##   cnt         per-row weights; row i of the estimating function is
  ##               multiplied by cnt[i] before summation.
  ##   alpha.start plug-in ("hat") value of alpha; its first p.alpha entries are
  ##               reshaped into a matrix with 2 columns.
  ##   beta, gamma nuisance parameters forwarded to the helper functions;
  ##               gamma is additionally indexed by row name "a1".
  ##   Hessian     accepted for interface compatibility; not used here.
  ##   y.hat, l1.hat  forwarded unchanged to the helper functions.
  ##   max.step    maximum number of outer sweeps, also passed to optim() as
  ##               `maxit` for every inner optimisation.
  ##   thres       convergence threshold on the relative squared change.
  ##   SEED        seed used only when dr.warm == "random".
  ##   dr.warm     NULL or "MLE": start from alpha.start; "random": start from
  ##               runif(-0.2, 0.2); anything else: start from dr.warm.obj.
  ##   dr.warm.obj custom starting matrix (same layout as the returned value).
  ##
  ## Returns: the final parameter matrix (rows 1 and 2:5 are updated in turn).
  ##
  ## Relies on objects defined outside this function: alpha.true, expit, g, g1,
  ## h, f., fk, E.fk, y.hat.func, p.a0.func, ee.assume.true.alpha, p.y.cond,
  ## vec.0 and vec.1.

  # NOTE(review): the number of alpha parameters is read from the global
  # `alpha.true` rather than from `alpha.start` -- confirm this is intended
  # outside of simulation studies.
  p.alpha <- length(alpha.true)
  startpars <- c(alpha.start)

  l0 <- data[, c("l0.1", "l0.2")]
  a0 <- data[, "a0"]
  l1 <- data[, "l1"]
  a1 <- data[, "a1"]
  y  <- data[, "y"]
  nb <- nrow(data)

  # Plug-in quantities evaluated once at the starting value of alpha.
  alpha.hat <- matrix(startpars[1:p.alpha], ncol = 2)
  theta.hat <- exp(l0 %*% t(alpha.hat))

  # Nuisance function for a1, evaluated once.
  p.a1 <- c(expit(cbind(l1, a0, l0) %*% gamma["a1", ]))

  # Column of theta selected for each row: 2..5, determined by (a0, l1).
  k <- 2 + 2 * (1 - a0) + (1 - l1)
  row.k <- cbind(seq_len(nb), k)

  theta.k.inv.hat    <- 1 / theta.hat[row.k]
  theta.k.inv.a1.hat <- theta.k.inv.hat^a1

  theta..inv.hat    <- 1 / theta.hat[, 1]
  theta..inv.a0.hat <- theta..inv.hat^a0

  # Objective for one coordinate block.
  #   par.update  candidate values for the block being optimised
  #   parn        "alpha1" (row 1) or "alpha2345" (rows 2:5); any other value
  #               leaves `pars` untouched
  #   pars        full current parameter matrix
  # Returns the sum of squared entries of the count-weighted estimating
  # function (a 5 x 2 matrix).
  dr.objective <- function(par.update, parn = "not", pars){
    if (parn == "alpha1")    pars[1, ]   <- par.update
    if (parn == "alpha2345") pars[2:5, ] <- par.update

    theta <- exp(l0 %*% t(pars))

    theta.k.inv    <- 1 / theta[row.k]
    theta.k.inv.a1 <- theta.k.inv^a1

    ### m = 1
    u1 <- y * theta.k.inv.a1
    if (ee.assume.true.alpha){
      E.u1 <- g(l1, a0, l0, y.hat)
    }else{
      E.u1 <- g(l1, a0, l0, p.y.cond, p.a1,
                theta.k.inv=theta.k.inv.hat, gamma, nb,y.hat)
    }

    if (ee.assume.true.alpha){
      d1.row.k <- -l0 * a1 * y.hat.func(l0,a0,l1,a1=vec.0,y.hat)
      E.d1.row.k <- -l0 * g1(l1, a0, l0, y.hat,gamma)
    }else{
      d1.row.k <- -l0 * a1 * y.hat.func(l0,a0,l1,a1,y.hat) * theta.k.inv.a1.hat
      E.d1.row.k <- -l0 * g1(l1, a0, l0, theta.k.inv=theta.k.inv.hat, gamma, nb,y.hat)
    }

    # Scatter row i of d1.row.k / E.d1.row.k into slice [i, k[i], ] of an
    # nb x 5 x 2 array; every other entry stays zero.
    d1   <- array(0, dim = c(nb, 5, 2))
    E.d1 <- array(0, dim = c(nb, 5, 2))
    for (j in seq_len(ncol(d1.row.k))){
      slot <- cbind(seq_len(nb), k, rep(j, nb))
      d1[slot]   <- d1.row.k[, j]
      E.d1[slot] <- E.d1.row.k[, j]
    }

    ### m = 0
    theta..inv.a0 <- (1 / theta[, 1])^a0

    u0 <- y * theta.k.inv.a1 * theta..inv.a0
    p.a0 <- p.a0.func(l0,gamma)
    E.u0 <- p.a0 * h(a0=vec.1, l0, theta..inv.a0= theta..inv.a0.hat,
                     p.y.cond, p.a1, theta.k.inv=theta.k.inv.hat,
                     gamma, nb,
                     beta, y.hat,l1.hat) +
      (1-p.a0) * h(a0=vec.0, l0, theta..inv.a0= theta..inv.a0.hat,
                   p.y.cond, p.a1, theta.k.inv=theta.k.inv.hat, gamma, nb,
                   beta, y.hat,l1.hat)

    d0.row.1 <- f.(a0,l0, theta..inv.a0= theta..inv.a0.hat,
                   p.y.cond, p.a1, theta.k.inv=theta.k.inv.hat
                   , gamma, nb,
                   beta,y.hat,l1.hat)
    # NOTE(review): the result is written into an nb x 4 x 2 slice below;
    # confirm that fk / E.fk return their values in that (column-major) order.
    d0.row.k <- fk(a0,l0, theta..inv= theta..inv.hat,
                   p.y.cond, p.a1, theta.k.inv=theta.k.inv.hat, gamma, nb,
                   beta,y.hat,l1.hat)

    E.d0.row.1 <- -l0 * p.a0 * h(a0=a0, l0,
                                 theta..inv.a0= theta..inv.a0.hat,
                                 p.y.cond, p.a1, theta.k.inv=theta.k.inv.hat,
                                 gamma, nb,
                                 beta, y.hat,l1.hat)
    E.d0.row.k <- E.fk(p.a0=p.a0,l0, theta..inv= theta..inv.hat,
                       p.y.cond, p.a1, theta.k.inv=theta.k.inv.hat, gamma, nb,
                       beta,y.hat,l1.hat)

    d0   <- array(0, dim = c(nb, 5, 2))
    E.d0 <- array(0, dim = c(nb, 5, 2))
    for (i in seq_len(nb)) {
      d0[i, 1, ]   <- d0.row.1[i, ]
      E.d0[i, 1, ] <- E.d0.row.1[i, ]
    }
    d0[, 2:5, ]   <- d0.row.k
    E.d0[, 2:5, ] <- E.d0.row.k

    # The length-nb residual vectors recycle along the first array dimension,
    # i.e. every entry of slice [i, , ] is multiplied by the residual of row i.
    total <- (d1 - E.d1) * (u1 - E.u1) + (d0 - E.d0) * (u0 - E.u0)

    # Count-weighted sum over rows, accumulated in row order.
    tmp <- array(0, dim = c(5, 2))
    for (i in seq_len(nb)){
      tmp <- tmp + total[i, , ] * cnt[i]
    }

    sum(tmp^2)
  }

  # Box-constrained quasi-Newton step for one coordinate block. The bounds
  # [-10, 10] are sized to the block being optimised.
  Optim <- function(par.update, fn, parn, pars){
    n.par <- length(par.update)
    optim(par.update, fn, parn = parn, pars = pars,
          control = list(maxit = max.step), method = "L-BFGS-B",
          lower = rep(-10, n.par), upper = rep(10, n.par))
  }

  # Starting value. A NULL dr.warm means "MLE"; it must be resolved before any
  # comparison because `NULL == "random"` is a zero-length condition.
  warm <- if (is.null(dr.warm)) "MLE" else dr.warm
  if (warm == "random"){
    set.seed(SEED)
    pars <- matrix(runif(p.alpha, -0.2, 0.2), ncol = 2)
  }else if (warm == "MLE"){
    # By default: use MLE as both nuisance and starting value
    pars <- matrix(startpars[1:p.alpha], ncol = 2)
  }else{ # custom dr.warm
    if (is.null(dr.warm.obj)){
      stop("dr.warm.obj must be supplied when dr.warm is neither 'MLE' nor 'random'",
           call. = FALSE)
    }
    pars <- dr.warm.obj
  }

  # Relative squared change between two parameter blocks.
  Diff <- function(x, y) sum((x - y)^2) / sum(x^2 + thres)

  diff <- thres + 1
  step <- 0
  total.count <- 0
  while (diff > thres && step < max.step){
    if (step %% 10 == 0) cat("This is the ", step, "th step. Diff = ", diff, "\n")
    step <- step + 1

    opt1 <- Optim(pars[1, ], dr.objective, "alpha1", pars)
    diff1 <- Diff(opt1$par, pars[1, ])
    pars[1, ] <- opt1$par

    opt2 <- Optim(pars[2:5, ], dr.objective, "alpha2345", pars)
    diff2 <- Diff(opt2$par, pars[2:5, ])
    pars[2:5, ] <- opt2$par

    diff <- max(diff1, diff2)
    # optim()$counts holds two entries (function and gradient evaluations), so
    # total.count becomes a length-2 vector after the first sweep.
    total.count <- total.count + opt1$counts + opt2$counts
  }
  print(paste0("Numbers of iterations for DR is: ", total.count))

  pars
}

dr.estimate.noiterate <- function(data, cnt, alpha.start,beta, gamma,y.hat,l1.hat,
                                  Hessian=FALSE,SEED = 1, dr.warm=NULL, dr.warm.obj=NULL){
  ## Run the doubly robust estimator once (no outer re-estimation of the
  ## nuisance parameters) and return alpha as a matrix with 2 columns.
  ##
  ## All arguments are forwarded unchanged to dr.estimate.onestep(); see that
  ## function for their meaning. When the global flag MESSAGE is TRUE, the
  ## rounded alpha estimate and the beta value that was passed in are printed.

  alpha.dr <- dr.estimate.onestep(data, cnt, alpha.start, beta, gamma,
                                  Hessian = Hessian, y.hat = y.hat,
                                  l1.hat = l1.hat, SEED = SEED,
                                  dr.warm = dr.warm, dr.warm.obj = dr.warm.obj)

  if (MESSAGE){
    # Report the beta actually used for this fit (the `beta` argument), not a
    # variable from the calling environment.
    print(paste("DR One Step: "," Alpha: ",paste(round(alpha.dr,5),collapse=", ")," Beta: ",paste(round(beta,5),collapse=", ")))
  }

  matrix(c(alpha.dr), ncol = 2)
}
